查看原文
其他

ICML 2024 | 图上的泛化挑战:从不变性到因果性

吴齐天 PaperWeekly
2024-08-23


©PaperWeekly 原创 · 作者 | 吴齐天
单位 | 上海交通大学博士生
研究方向 | 机器学习与图深度学习


图机器学习目前仍然是一个热门的研究方向,特别是在 AI4Science 的浪潮推动下,涌现出越来越多样化的图数据的应用场景。不同于一般的图像和文本数据,图(Graph)是一种数学抽象后的数据形式,用以描述一个系统中实体的属性和实体之间的相互作用关系。因此,图结构数据不仅可以描述不同尺度、不同规模的真实物理系统(如分子、蛋白质、社交网络等),也可以表达某种建模后的拓扑关系(如场景图、工业流程、思维链等)。
如何构建面向图数据的通用基础模型(Foundation Model)是近期备受关注的研究问题。尽管现有的方法,如图神经网络(GNNs)、Graph Transformer 等,已经展现出了强大的表征能力,但是机器学习模型在图结构数据上的泛化性(Generalization)依然是一个未被充分研究的开放问题 [1,2,3]

一方面,图数据涉及的非欧空间、拓扑几何关系会大大增加建模上的困难,使得现有的用于提升模型泛化性的方法很难胜任 [4,5,6]。另一方面,图数据的分布偏移(Distribution Shift),即训练与测试数据的分布差异,会来源于更加复杂的引导因素(如拓扑结构)和外部环境,使得这一问题的研究更具挑战 [7,8]

▲ 训练集和测试集的分布差异带来了泛化性挑战,图数据上的分布偏移可能来源于多个因素(如特征、标签、结构等)

本文重点关注图上的分布外泛化(Out-of-Distribution Generalization,简称 OOD)问题。首先会介绍 OOD 问题的定义和典型的场景,然后会展开介绍近期发表的三个工作:
  • Handling Distribution Shifts on Graphs: An Invariance Perspective(ICLR 2022)
  • Graph Out-of-Distribution Generalization via Causal Intervention(WWW 2024)
  • Learning Divergence Fields for Shift-Robust Graph Representations(ICML 2024)

他们从不同角度来分析和求解图上的 OOD 问题,其中第一个工作主要基于不变性原理,后两个工作主要利用因果干预。此外,本文还会讨论不同方法的局限性,以及未来可以进一步探索的方向。



研究动机和问题描述

1.1 开放环境下的分布偏移

泛化性问题之所以如此重要,正因为真实场景中的模型往往需要与一个开放、动态、复杂的环境进行交互。在实际场景中,由于观测与资源的有限性,训练数据无法包含所有可能出现的环境,模型也无法在训练过程中预知所有未来可能出现的情况。在测试阶段模型极可能会接触到与训练主体分布不一致的样本。研究机器学习模型在训练分布以外的测试数据上的表现,就是分布外泛化(OOD)问题所关心的重点。

▲ 几种含分布偏移的图数据场景,机器学习模型需要从有限的训练数据泛化到新的测试分布

在这一设定下,由于测试数据/分布对于训练过程是严格不可见/未知的,因此必然需要对数据生成过程的结构性假设(Structural Assumption)作为前提。相反,如果不存在任何的数据假设,分布外泛化就是不可能实现的(No-free Lunch Theorem)。因此,首先需要说明的是,OOD 问题的研究目标不是消除所有的假设,而是 1)在合理的假设条件下如何最大化模型的泛化能力,以及 2)如何合理地添加/减少假设来保障模型处理某种分布偏移的能力。

1.2 图上的分布外泛化问题
一般的分布外泛化问题可以被简单描述为:当 时,如何设计有效的机器学习方法?这里沿用文献中常用的设定,假设数据分布由一个潜在的环境 控制,则在给定环境下数据的生成可以写为 . 对于 OOD 问题,训练数据和测试数据可以假定为从不同的环境中生成的,此时问题可以被进一步阐述为如何学习一个预测器 使得它在所有环境 下都表现良好。
特别地,对于图结构数据,输入数据还会包含结构信息。此时,按照图结构存在的形式,可以大致分为两类问题:节点级别(node-level)的任务和图级别(graph-level)的任务。下图给出了分布外泛化问题的定义。

▲ 图分布外泛化问题的定义,其中按照图结构的存在形式对图级别和节点级别的任务作进一步的区分。特别地,对于节点级任务,由于图结构引入了节点样本间的非独立性,[5] 提出以节点中心子图为单位将数据集划分为互相独立的单元。

正如前文所述,OOD 问题需要对数据生成作一定的假设,再以此为基础构建可泛化的机器学习方法。下面将会介绍近期工作提出的两类具体方法,分别利用不变性原理和因果干预,来实现图数据的分布外泛化。


基于不变性原理的泛化
基于不变性原理(Invariance Principle)的学习方法,简称不变学习(Invariant Learning)[9, 10, 11],旨在通过设计新的学习算法来引导机器学习模型利用数据中的不变关系。这里的不变关系指在所有环境下都成立的输入 与标签 之间的关系。因此,当预测器 (例如神经网络)成功学习到这层不变关系后,它就能在分布不同的数据之间实现泛化。
与不变关系相对立的,被称为冗余相关性(Spurious Correlation),即仅在部分环境下成立的 的关系,这部分相关性会迷惑预测器在训练过程中一味地提升训练精度,最终发生过拟合。
基于以上的描述,可以看到,不变学习需要依赖数据生成的不变性假设,即在数据中存在 的预测关系在不同环境下保持不变。从数学上,可以描述为:存在映射 ,使得 满足 . 那么,在图结构数据上要如何定义不变性假设?以及,这种不变性假设对于图结构数据是否是合理的?
下面我们简单回顾论文 [5]:Handling Distribution Shifts on Graphs: An Invariance Perspective(ICLR 2022)。这个工作提出了将不变性原理应用到图结构数据,并重新定义了图上的不变性假设。

2.1 图上的不变性假设

为了将拓扑结构信息融入不变性原理,[5] 以图中每个节点为中心的子图为单位,考虑子图内部所有节点特征对中心节点标签的贡献,将后者具体分解为不变特征和冗余特征两个部分。这种定义方式受启发于图同构检测的 Weisfeiler Leman 算法,并且具有一定的灵活性。下图展示了 [5] 中定义的不变性假设的示意图和引用网络的例子。

▲ 图上的不变性假设(左)和引用网络的例子(右)。在引用网络中,每个节点是一篇论文,假设需要预测的标签(y)是论文的研究领域,节点标签包含论文的发表刊物(x1)和引用量(x2),环境(e)是论文发表的时间。在这一例子中,x1 就是不变特征,它与 y 的关系是与环境无关的;而 x2 就是冗余特征,它与 y 虽然存在很强的相关性,但这种相关性会随着时间发生变化。因此,在这个场景中,理想的预测器应该利用 x1 中的信息,才能实现在不同环境间的泛化

2.2 学习算法:探索-外推风险最小化
在不变性假设的前提下,一种最直接的思路就是约束模型在不同环境间损失的差异,以帮助模型学习到在不同环境间保持不变的预测关系。然而,现实的输入数据中通常没有环境标签(即每个样本与环境的对应关系未知),这导致无法直接计算不同环境间损失的差异。
为此,[5] 提出了探索-外推风险最小化(Explore-to-Extrapolate Risk Minimization,简称 EERM),通过引入 K 个环境生成器,对输入数据进行“多样化”的扩充,由此模拟不同环境下的输入样本。此外,[5] 证明了采用的目标函数可以在训练过程中使得模型趋向分布外泛化的最优解。

▲ [5] 提出的探索-外推风险最小化(EERM)算法:优化的内层目标是最大化 K 个环境生成器生成的样本的“多样性”;外层目标是以生成的 K 个虚拟环境下的数据计算损失的均值和方差,用于预测器的训练。

除了生成环境,另一项近期的研究 [12] 提出从观测数据中推断潜在的环境,并引入了一个额外的用于环境推断的模型,在训练过程中迭代地优化它与预测器。而 [13] 则从数据增强的角度出发,利用不变性原理引导生成的数据保持不变特征,从而更加有利于模型学习到可泛化的不变关系。


基于因果干预的泛化

不变学习需要假设数据中存在不变关系,并且这种不变关系可以从数据中学习得到。这也一定程度上限制了此类方法的适用性,即模型只能在与训练数据共享某种不变关系的测试数据上才能有泛化保障,而对于其他分布外的测试数据,模型的泛化性能如何则是未知的。

下面我们介绍近期工作 [14]:Graph Out-of-Distribution Generalization via Causal Intervention(WWW 2024)。这篇文章提出了另一种方法,从因果干预的角度实现分布外泛化。不同于不变学习,这种方法不依赖数据生成的不变性假设,而是通过训练算法引导模型学习从 的因果性。
3.1 图学习的因果性
首先,我们来看一下一般的机器学习模型(如图神经网络)所产生的变量间的因果关系。考虑输入 (例如图中以节点为中心的子图),预测的标签 和影响数据分布的环境 。当采用一般的监督学习目标(如经验风险最小化或极大似然估计)训练模型后,它们之间的依赖关系如下图所示。

▲ 因果图中包含三层依赖关系:1)从输入 到预测标签 ,这是由模型的前馈计算给定的;2)从环境 到输入 ,这是数据生成的假设给出的;3)从环境 到预测标签 ,这是模型的训练过程导致的(由于模型拟合了训练数据,而训练数据特定于某个环境)
以上的因果图也揭示了传统训练方法的局限性,即无法实现分布外泛化的原因。可以看到,这里的输入 和标签 都是环境 的共同结果,即它们被环境这一混淆因子(Confounder)给关联了起来。在训练过程中,模型会不断拟合训练数据,这就导致预测器 实际学到的是特定环境下的输入与标签的关系。

[14] 采用了一个社交网络的例子来理解这一学习过程。假设在社交网络中需要预测每个用户(节点)的兴趣爱好,而用户的兴趣爱好会与年龄、社交圈有很大的关系。因此如果一个模型在大学的社交网络数据上训练,就很容易预测一个用户具有“喜欢篮球”的爱好。这是因为在大学的用户群体里,“喜欢篮球”的用户占比较高,是“大学”这一环境导致了用户和“喜欢篮球”之间的强相关性。

然而当模型迁移到 Linkedin 的社交网络,这层预测关系就不再成立,因为 Linkedin 的用户年龄和兴趣爱好分布较为多样化。从这一例子可以看出,一个理想的模型需要学习到输入和标签之间的因果关系,才能在不同环境间实现泛化。

为了解决上述的问题,一种典型的思路就是采用因果干预 [15],即切断因果图中 的依赖关系,通过破坏环境对输入和标签的共同影响来引导模型学习因果性。
下图展示了这一思路的示意图。在因果推断中,这种切断指向某个变量依赖路径的操作,可以用 do 算子来表示。因此如果要在训练过程中考虑切断 的依赖关系,事实上就是把传统的优化目标 (即观测数据的概率似然)替换为 。进一步地,使用因果推断里的后门调整 [15],可以推导得出最终的目标函数形式:

▲ 基于因果干预的学习目标

然而,要计算这里的目标函数需要数据中存在环境信息,即每个样本 与环境 的对应关系。事实上,在大部分情况下,环境是无法被观测到的。

3.3 学习算法:变分环境调整

为了使得上述思路变得可行,[14] 推导了因果干预目标函数的变分下界,利用从观测数据中自动学习环境信息来解决环境信息缺失的问题。具体地,[14] 引入了一个变分分布 ,从而得到了下图所示的新的替代目标函数。

▲ 原始因果干预目标的变分下界和各项的具体实现

新的目标函数中包含了三项,[14] 将它们分别实例化为了环境推断器、GNN 预测器和(非参的)环境先验分布。其中,前两个模型都包含可训练的参数,会在训练过程中进行联合优化。
[14] 也通过实验验证了提出方法的有效性。特别地,因为提出的方法 CaNet 不依赖于具体的模型主干,[14] 考虑使用 GCN 和 GAT 分别作为主干并和其他 OOD 方法(包括前文提到的 EERM)作了对比。下表展示了部分对比结果。

▲ Arxiv 和 Twitch 数据集的分布外泛化实验结果。其中 Arxiv 使用论文发表时间进行数据划分,Twitch 则使用不同的子图进行划分

3.3 因果干预的隐含假设

至此,我们介绍了基于因果干预的分布外泛化方法。正如前文提到的,分布外泛化都需要对数据生成的假设才能取得有保障的泛化性能。那么,对于因果干预这类方法,在建模过程中用到了哪些假设呢?

事实上,和不变学习的求解思路不同的是,因果干预方法并没有从显式的假设出发,而是在建模和分析的过程中使用了隐含的假设:在输入和标签之间,只存在一个混淆因子(也就是环境)。这一假设在一定程度上简化了分析,但也引入了近似误差。对于更复杂的场景,未来还有很大的探索空间。


含隐式图结构的泛化

在前面的讨论中,我们假设输入数据的结构信息都是已被观测且完备的。对于更一般的图数据,结构信息可能只被部分观测到,甚至完全是未知的。我们称这类数据为隐式图结构数据。另一方面,图结构数据的分布偏移会涉及到数据背后的底层结构,如何刻画几何结构对数据分布的影响也是一个尚待解决的问题。

为此,近期的工作 [16]:Learning Divergence Fields for Shift-Robust Graph Representations(ICML 2024),从扩散方程与消息传递机制的内在联系出发,结合前文介绍的因果干预思路,设计了一种同时适用于显式和隐式结构分布外泛化的学习方法。

4.1 从消息传递到扩散方程

消息传递(Message Passing)机制是图神经网络和图 Transformer 的共用底层设计,即在每层中传播其他节点的信息来更新当前节点的表征。本质上,如果将神经网络的层看作连续时间的近似,那么消息传递可以看作图上扩散方程的离散形式 [17, 18]。下图通过类比的方式展示了二者的联系。

▲ 消息传递(即 GNN 和 Transformer 的层间更新)可以看作连续扩散方程的离散迭代:图中的节点视为流形上的位置,节点的表征视为热信号,表征随层数加深的更新视为热信号随时间的变化,每层节点间的相互作用视为流形上不同位置间的交互

特别地,扩散方程中的扩散系数(Diffusivity) 控制了扩散过程中节点间的相互作用。当采用局部或全局扩散形式时,扩散方程的离散迭代会分别导出图神经网络 [18] 和 Transformer [19] 的层间更新公式。
这种确定的扩散系数设计无法建模样本间相互作用关系的多重效应和不确定性。因此,[16] 提出将扩散系数 定义为从概率分布中采样得到的随机样本,对应的扩散方程会输出一条随机的连续变化轨迹(如下图所示)。

▲ 将扩散系数 定义为随机变量后,扩散方程每一时刻的散度场(即节点表征在当前层的变化量)将会变为随机的,这可以建模节点间相互作用的不确定性
然而,如果直接采用传统监督学习的目标函数进行训练,上述模型并不能很好的泛化。原因类似于前文介绍的图学习的因果性。具体地,在这里考虑的扩散模型中,输入 (如一张图)和输出 (如图中节点的标签)都被扩散系数 共同关联了起来,扩散系数可以视为一种特定于数据集的环境,它决定了样本间的相互依赖关系。因此在有限训练数据上完成训练的模型会倾向于学习到特定于训练集的相互作用关系,而无法泛化到新的测试数据。

4.2 因果引导的散度场学习

为了解决上述的挑战,我们再一次使用因果干预方法,在训练过程中切断扩散系数 和输入 之间的依赖关系。不同于之前工作 [14] 中输入到输出的映射关系被一个预测器给定,这里从 需要经历一个扩散过程的多步迭代(对应于 GNN/Transformer 中的多层更新)。因此,我们需要对扩散过程的每一步都作因果干预。
然而,由于扩散系数是建模中的一种抽象表示,它在实际中是无法被观测的(类似于前文提到的环境)。为此,[16] 拓展了 [14] 中采用的变分推断方法,为扩散过程对应的目标函数推导了一个变分下界,作为对扩散过程每一步进行因果干预的近似目标。

▲ [16] 提出的学习算法,为扩散模型的每一步估计扩散系数,并进行因果干预,它可以引导模型学习到从输入到输出的稳定因果关系,从而提升在不同数据分布间的泛化能力

作为上述方法的实现,[16] 提出了三种具体的模型设计:

  • GLIND-GCN:扩散系数为常数矩阵,采用归一化的图邻接矩阵;

  • GLIND-GAT:扩散系数为时变矩阵,采用图注意力网络实现;

  • GLIND-Trans:扩散系数是时变矩阵,采用全局注意力网络实现。

对于 GLIND-Trans,为了解决全局注意力计算的平方复杂度问题,[16] 进一步采用了 DIFFormer [19] 中的线性注意力函数设计。下表展示了在含隐式结构的场景下的部分实验对比结果。

▲ 含未知结构数据的分布外泛化实验结果。其中使用 KNN 来构造样本间的依赖结构,通过不同 K 值与叠加不同的向量旋转角度来引入分布偏移



总结与讨论

本文简单介绍了近期关于分布外泛化的研究,主要内容以三个发表的工作 [5, 14, 16] 为基础,分别从不变学习和因果干预的角度出发,提出的方法适用于显式和隐式图结构。正如前文所述,研究 OOD 问题通常需要对数据的假设。基于此,未来的研究可能可以关注在沿着现有的假设精进目前的方法或分析模型泛化的极限,也可以探索其他假设条件下如何实现泛化。

另一个与分布外泛化高度相关的研究问题是分布外检测(OOD Detection)[20, 21, 22]。与OOD 泛化不同的是,OOD 检测旨在研究如何在训练过程中使得模型具备识别测试阶段出现的分布外样本的能力。未来的研究也可以侧重于探索两类问题的联系。

参考文献

[1] Garg et al., Generalization and Representational Limits of Graph Neural Networks, ICLR 2020.
[2] Koh et al., WILDS: A Benchmark of in-the-Wild Distribution Shifts, ICML 2021
[3] Morris et al., Position: Future Directions in the Theory of Graph Machine Learning, ICML 2024.
[4] Zhu et al., Shift-Robust GNNs: Overcoming the Limitations of Localized Graph Training Data, NeurIPS 2021.
[5] Wu et al., Handling Distribution Shifts on Graphs: An Invariance Perspective, ICLR 2022.
[6] Li et al., OOD-GNN: Out-of-Distribution Generalized Graph Neural Network, TKDE 2022.
[7] Yehudai et al., From Local Structures to Size Generalization in Graph Neural Networks, ICML 2021.
[8] Li et al., Size Generalization of Graph Neural Networks on Biological Data: Insights and Practices from the Spectral Perspective, Arxiv 2024.
[9] Arjovsky, et al., Invariant Risk Minimization, Arxiv 2019.
[10] Rojas-Carulla, et al., Invariant Models for Causal Transfer Learning, JMLR 2018.
[11] Krueger et al., Out-of-Distribution Generalization via Risk Extrapolation, ICML 2021.
[12] Yang et al., Learning Substructure Invariance for Out-of-Distribution Molecular Representations, NeurIPS 2022.
[13] Sui et al., Unleashing the Power of Graph Data Augmentation on Covariate Distribution Shift, NeurIPS 2023.
[14] Wu et al., Graph Out-of-Distribution Generalization via Causal Intervention, WWW 2024.
[15] Pearl et al.,  Causal Inference in Statistics: A Primer, 2016.
[16] Wu et al., Learning Divergence Fields for Shift-Robust Graph Representations, ICML 2024.
[17] Freidlin et al., Diffusion Processes on Graphs and the Averaging Principle, The Annals of probability 1993.
[18] Chamberlain et al., GRAND: Graph Neural Diffusion, ICML 2021.
[19] Wu et al., DIFFormer: Scalable (Graph) Transformers Induced by Energy Constrained Diffusion, ICLR 2023.
[20] Wu et al., Energy-based Out-of-Distribution Detection for Graph Neural Networks, ICLR 2023.
[21] Liu et al., GOOD-D: On Unsupervised Graph Out-Of-Distribution Detection, WSDM 2023.

[22] Bao et al., Graph Out-of-Distribution Detection Goes Neighborhood Shaping, ICML 2024.



更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编



🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧


·
·
·

继续滑动看下一个
PaperWeekly
向上滑动看下一个

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存